import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import gurobipy as gp
from gurobipy import GRB
import os
# from xbcausalforest import XBCF
from DataGeneration.DataGeneration import GenerateData
from DataGeneration.PosteriorGeneration import ComputePosterior
import stan
import sqlite3
from joblib import Parallel, delayed

# Fixed evaluation grid: 100 x 100 equally spaced points on [-1, 1]^2,
# flattened to shape (10000, 2). With meshgrid's default 'xy' indexing,
# column 0 cycles fastest through the grid axis and column 1 steps once per
# 100 rows.
X_evaluation = np.column_stack(
    [axis.ravel() for axis in np.meshgrid(np.linspace(-1, 1, 100), np.linspace(-1, 1, 100))]
)

def maketable(t1, t2, t3):
    """Stack three equally shaped arrays along a new leading axis.

    Returns an array of shape ``(3, *t1.shape)`` whose i-th slice is the
    i-th argument (``np.stack`` raises if the shapes differ).
    """
    return np.stack([t1, t2, t3])

def Gub_constraint(Existing_policy, EA, NIP, N, X, epsilon):
    """Solve a risk-constrained policy-improvement problem with Gurobi.

    Searches for a linear threshold policy over the first two covariate
    columns -- y_i consistent with b0 * X[i, 0] + b1 * X[i, 1] - 1 >= 0 --
    that maximises

        sum_i (y_i - Existing_policy_i) * EA_i

    subject to

        (y - Existing_policy)^T (diag(NIP) / N) (y - Existing_policy) <= epsilon.

    Args:
        Existing_policy: length-N reference policy; used in the objective and
            risk term and as the MIP start for y (presumably a 0/1 vector --
            the caller in this file passes ``basepolicy(X)``).
        EA: length-N weights of the linear objective (the caller in this file
            passes the first output of ``ComputePosterior``).
        NIP: length-N weights placed on the diagonal of the risk matrix (the
            caller passes the second output of ``ComputePosterior``).
        N: number of units / binary decision variables; presumably equals
            ``len(Existing_policy)`` and ``X.shape[0]`` -- verify at call sites.
        X: covariate matrix; only columns 0 and 1 of rows 0..N-1 are used.
        epsilon: upper bound on the risk term (constraint named 'risk').

    Returns:
        The gurobipy Model after ``optimize()``. Variables are created in the
        order b[0], b[1], y, so the two hyperplane coefficients come first in
        the solution vector. The solve status is NOT checked here.
    """
    m = gp.Model('conditional')
    m.Params.LogToConsole = 0  # silence solver output
    # Hyperplane coefficients b[0], b[1]: each a 1-element continuous MVar in [0.001, 1000].
    b = [m.addMVar(1, lb=0.001, ub=1000, vtype=GRB.CONTINUOUS), m.addMVar(1, lb=0.001, ub=1000, vtype=GRB.CONTINUOUS)]
    # y[i]: the new policy's 0/1 decision for unit i.
    y = m.addMVar(N, vtype=GRB.BINARY)
    # Link y to the hyperplane (lin_i = b1*X[i,1] + b0*X[i,0] - 1):
    # y[i] = 1 forces lin_i >= 0, y[i] = 0 forces lin_i <= 0; points exactly on
    # the boundary may take either value.
    # NOTE(review): these products of binary y and continuous b are bilinear;
    # depending on the Gurobi version, Params.NonConvex = 2 may be required -- confirm.
    m.addConstrs(y[i] * (b[1] * X[i, 1] + b[0] * X[i, 0] - 1) >= 0 for i in range(N))
    m.addConstrs((1 - y[i]) * (b[1] * X[i, 1] + b[0] * X[i, 0] - 1) <= 0 for i in range(N))
    # Objective: change in total EA when switching from Existing_policy to y.
    EPI = y @ EA - Existing_policy @ EA
    # Assuming NIP is 1-D, np.diag builds a dense N x N diagonal matrix (memory ~ N**2).
    Meat = np.diag(NIP) / N
    # Expanded form of (y - p)^T Meat (y - p) with p = Existing_policy (Meat is symmetric).
    Risk = y @ Meat @ y - 2 * Existing_policy @ Meat @ y + Existing_policy @ Meat @ Existing_policy
    m.setObjective(EPI, GRB.MAXIMIZE)
    m.addConstr(Risk <= epsilon, name='risk')
    # Warm start from the existing policy: its risk term is 0. It is feasible overall
    # only if some admissible (positive) b reproduces it through the linking constraints.
    y.Start = Existing_policy
    m.optimize()
    return m

def Simulate_con(myseed=None, sigma=None, sigmakernel=None):
    """Run one simulation replicate and score the risk-constrained policies.

    Generates a training sample with ``GenerateData``, computes posterior
    quantities with ``ComputePosterior``, solves ``Gub_constraint`` once per
    risk budget in ``epsilon_all``, and evaluates every resulting linear
    policy on the fixed grid ``X_evaluation``.

    Depends on these module-level globals, defined by the driver code in
    this file: ``eyx1``, ``collecting_rule``, ``basepolicy``, ``N``,
    ``length_scale`` and ``epsilon_all``.

    Args:
        myseed: forwarded to ``GenerateData`` as ``seed``.
        sigma: forwarded to ``GenerateData`` as ``sigma``.
        sigmakernel: stored under ``'sigma_kernel'`` in the parameter dict
            passed to ``ComputePosterior``.

    Returns:
        np.ndarray of shape ``(2, len(epsilon_all))``.
        Row 0: for each budget, the fraction of grid points where the base
        policy already agrees with the optimal policy but the new policy
        disagrees with it.
        Row 1: the grid average of ``eyx1`` under the new policy.
    """
    k = 2  # forwarded to GenerateData; the policies below choose between actions 0 and 1
    DG = GenerateData(Ydist='normal', eyx=eyx1, sigma=sigma, collecting_rule=collecting_rule,
                      k=k, N=N, seed=myseed)
    X, A, Ey, Y = DG.GenerateData()
    paramlist = {'length_scale': length_scale, 'sigma_kernel': sigmakernel}
    EPI, NIP = ComputePosterior(X, A, Y, paramlist, trt_price=0)

    # One constrained optimisation per risk budget, each measured against
    # (and warm-started from) the base policy on the training covariates.
    base_train = basepolicy(X)
    m_all = [Gub_constraint(base_train, EPI, NIP, N, X, epsilon) for epsilon in epsilon_all]

    # Grid quantities that do not depend on the solved model are computed once
    # instead of once per epsilon. The optimal policy picks action 1 wherever
    # eyx1 is strictly larger under action 1.
    opt_eval = 1 * (eyx1(X_evaluation, 1) > eyx1(X_evaluation, 0))
    base_matches_opt = opt_eval == basepolicy(X_evaluation)

    NIP_all, AUI_all = [], []
    for m in m_all:
        # The first two solution values are b[0] and b[1] (Gub_constraint creates them first).
        # NOTE(review): this raises if the solve found no solution; consider checking m.SolCount.
        coef = m.X
        new_eval = 1 * ((coef[0] * X_evaluation[:, 0] + coef[1] * X_evaluation[:, 1] - 1) >= 0)
        NIP_all.append(np.logical_and(base_matches_opt, opt_eval != new_eval).mean())
        AUI_all.append(eyx1(X_evaluation, new_eval).mean())
    return np.array([NIP_all, AUI_all])


# ---------------------------------------------------------------------------
# Simulation driver. For every (N, sigma, length_scale, sigmakernel)
# combination, run `n_replicates` seeded replicates of Simulate_con in
# parallel and save the stacked results under `output_dir`. A combination
# whose output file already exists is skipped, so an interrupted run resumes.
# ---------------------------------------------------------------------------
mysigmalist = [1, 3]
myNlist = [200]
mylengthlist = [0.5, 1, 2]
mysigmakernel = [1, 2, 4]
n_replicates = 5
output_dir = 'Data/Continuous/'

# Risk budgets; Simulate_con solves one constrained policy per value.
epsilon_all = np.array([1e-3, 2e-3, 3.3e-3, 5e-3, 6.7e-3, 8e-3, 1e-2, 2e-2, 3.3e-2, 5e-2, 6.7e-2, 8e-2,
                        1e-1, 0.125, 0.15, 0.175, 2e-1, 3.3e-1, 6.7e-1, 8e-1, 1])


def eyx1(X, A):
    """Outcome-mean model used to generate data and to evaluate policies.

    Returns X[:, 0] + X[:, 1] + 4 * A * |X0| * |X1| * (1[X0 > 0 and X1 > 0] - 0.5),
    i.e. action 1 adds value only in the positive quadrant (presumably
    E[Y | X, A]; confirm against GenerateData).
    """
    return (+X[:, 0] + X[:, 1] + 4 * A * np.abs(X[:, 0]) * np.abs(X[:, 1]) * (
        (X[:, 1] > 0) * (X[:, 0] > 0) - 0.5))


def collecting_rule(X):
    """Return the probability of having action 1 (deterministic: 1 iff X[:, 0] > 0.5)."""
    return 1 * (X[:, 0] > 0.5)


basepolicy = collecting_rule

if __name__ == "__main__":
    # Create the output directory up front so np.save cannot fail after the
    # expensive simulations have finished.
    os.makedirs(output_dir, exist_ok=True)
    for N in myNlist:
        for sigma in mysigmalist:
            for length_scale in mylengthlist:
                for sigmakernel in mysigmakernel:
                    # File-name format kept unchanged so existing cached results are still found.
                    mydataname = "_".join(["GP", str(N), str(sigma), str(length_scale), str(sigmakernel), '.npy'])
                    output_path = os.path.join(output_dir, mydataname)
                    if os.path.exists(output_path):
                        continue
                    # Seed each replicate by its index so results are reproducible
                    # and the replicates differ from one another.
                    temp3 = Parallel(n_jobs=10)(
                        delayed(Simulate_con)(myseed=i, sigma=sigma, sigmakernel=sigmakernel)
                        for i in range(n_replicates))
                    result3 = np.array(temp3)
                    np.save(output_path, result3)
